{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45135531",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_io as tfio\n",
    "\n",
    "BATCH_SIZE=64\n",
    "MAX_EPOCHS=5\n",
    "NUM_COLUMNS=784\n",
    "\n",
    "KAFKA_SERVERS=\"kafka:9092\"\n",
    "KAFKA_CONSUMER_GROUP=\"mnistcg\"\n",
    "KAFKA_TRAIN_TOPIC=\"tf.public.mnist_train\"\n",
    "KAFKA_TEST_TOPIC=\"tf.public.mnist_test\"\n",
    "KAFKA_STREAM_TIMEOUT=9000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f600a8d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_kafka_record(record):\n",
    "    img_int = tf.io.decode_csv(record.message, [[0.0] for i in range(NUM_COLUMNS)])\n",
    "    img_norm = tf.cast(img_int, tf.float32) / 255.\n",
    "    label_int = tf.strings.to_number(record.key, out_type=tf.dtypes.int32)\n",
    "    return (img_norm, label_int)\n",
    "\n",
    "train_ds = tfio.IODataset.from_kafka(KAFKA_TRAIN_TOPIC, partition=0, offset=0, servers=KAFKA_SERVERS)\n",
    "train_ds = train_ds.map(decode_kafka_record)\n",
    "train_ds = train_ds.batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7718b20",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.layers.Dense(128, activation='relu'),\n",
    "    tf.keras.layers.Dense(10)\n",
    "])\n",
    "\n",
    "model.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(0.001),\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
    ")\n",
    "\n",
    "model.fit(train_ds,epochs=MAX_EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8e2b080",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_kafka_stream_record(message, key):\n",
    "    img_int = tf.io.decode_csv(message, [[0.0] for i in range(NUM_COLUMNS)])\n",
    "    img_norm = tf.cast(img_int, tf.float32) / 255.\n",
    "    label_int = tf.strings.to_number(key, out_type=tf.dtypes.int32)\n",
    "    return (img_norm, label_int)\n",
    "\n",
    "test_ds = tfio.experimental.streaming.KafkaGroupIODataset(\n",
    "    topics=[KAFKA_TEST_TOPIC],\n",
    "    group_id=KAFKA_CONSUMER_GROUP,\n",
    "    servers=KAFKA_SERVERS,\n",
    "    stream_timeout=KAFKA_STREAM_TIMEOUT,\n",
    "    configuration=[\n",
    "        \"session.timeout.ms=10000\",\n",
    "        \"max.poll.interval.ms=10000\",\n",
    "        \"auto.offset.reset=earliest\"\n",
    "    ],\n",
    ")\n",
    "\n",
    "test_ds = test_ds.map(decode_kafka_stream_record)\n",
    "test_ds = test_ds.batch(BATCH_SIZE)\n",
    "\n",
    "model.evaluate(test_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1b33fe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_and_predict(pixels):\n",
    "    test = tf.constant([pixels])\n",
    "    tf.shape(test)\n",
    "    test_norm = tf.cast(test, tf.float32) / 255.\n",
    "\n",
    "    prediction = model.predict(test_norm)\n",
    "    number = tf.nn.softmax(prediction).numpy().argmax()\n",
    "    \n",
    "    pixels_array = np.asarray(pixels)\n",
    "    raw_img = np.split(pixels_array, 28)\n",
    "    plt.imshow(raw_img)\n",
    "    plt.title(number)\n",
    "    plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d4e0c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "pixels = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,108,43,6,6,6,6,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10,84,248,254,254,254,254,254,241,45,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,90,254,254,254,223,173,173,173,253,156,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,79,157,228,245,251,188,63,17,0,0,54,252,132,0,0,0,0,0,0,0,0,0,0,0,0,0,0,32,254,254,254,244,131,0,0,0,0,13,220,254,122,0,0,0,0,0,0,0,0,0,0,0,0,0,0,83,254,225,160,47,0,0,0,0,59,211,254,206,50,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,21,14,0,0,0,2,17,146,245,250,194,12,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,140,140,171,254,254,254,203,55,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,211,254,254,254,254,179,211,254,254,202,171,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,167,233,193,69,16,3,9,16,107,231,248,195,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,73,229,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,26,99,252,254,146,0,0,0,0,0,0,0,0,79,142,0,0,0,0,0,0,0,0,0,26,28,116,147,247,254,239,150,22,0,0,0,0,0,0,0,0,175,230,174,155,66,66,132,174,174,174,174,250,255,254,192,189,99,36,0,0,0,0,0,0,0,0,0,0,106,226,254,254,254,254,254,254,254,254,217,151,80,43,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4,7,114,114,114,46,5,5,5,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]\n",
    "plot_and_predict(pixels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8be8586d",
   "metadata": {},
   "outputs": [],
   "source": [
    "pixels = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,191,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,255,255,255,0,64,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,255,255,255,255,255,255,255,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,255,255,255,191,0,191,255,255,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,255,128,255,255,255,255,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,255,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,255,255,255,255,255,128,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,255,255,0,128,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,255,64,0,64,191,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,128,0,0,0,64,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,255,255,255,0,0,0,0,0,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,191,255,255,64,0,0,0,0,255,255,255,191,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,128,255,255,191,128,191,128,191,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,255,255,255,255,255,255,255,255,255,255,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,255,255,255,255,255,255,255,255,128,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,64,191,255,255,255,255,255,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]\n",
    "plot_and_predict(pixels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2261ad58",
   "metadata": {},
   "outputs": [],
   "source": [
    "manual_ds = tfio.experimental.streaming.KafkaGroupIODataset(\n",
    "    topics=[KAFKA_TEST_TOPIC],\n",
    "    group_id=\"mnistcg2\",\n",
    "    servers=KAFKA_SERVERS,\n",
    "    stream_timeout=KAFKA_STREAM_TIMEOUT,\n",
    "    configuration=[\n",
    "        \"session.timeout.ms=10000\",\n",
    "        \"max.poll.interval.ms=10000\",\n",
    "        \"auto.offset.reset=earliest\"\n",
    "    ],\n",
    ")\n",
    "\n",
    "manual_ds = manual_ds.map(decode_kafka_stream_record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51561ade",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(pixels):\n",
    "    test = tf.constant([pixels])\n",
    "    tf.shape(test)\n",
    "    prediction = model.predict(test)\n",
    "    return tf.nn.softmax(prediction).numpy().argmax()\n",
    "\n",
    "\n",
    "iterator = iter(manual_ds)\n",
    "figure, axis = plt.subplots(2, 2)\n",
    "for i in range(0,2):\n",
    "    for j in range(0,2):\n",
    "        record = iterator.get_next()\n",
    "        image = record[0].numpy()\n",
    "        axis[i, j].imshow(np.split(np.asarray(image), 28))\n",
    "        axis[i, j].set_title(predict(image))\n",
    "        axis[i, j].axes.get_xaxis().set_visible(False)\n",
    "        axis[i, j].axes.get_yaxis().set_visible(False)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
